// Copyright 2020 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/// Tests for data_type.h

#include "tensorstore/data_type.h"

#include <stddef.h>
#include <stdint.h>

#include <cmath>
#include <memory>
#include <string>
#include <type_traits>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include <nlohmann/json.hpp>
#include "tensorstore/index.h"
#include "tensorstore/internal/elementwise_function.h"
#include "tensorstore/serialization/serialization.h"
#include "tensorstore/serialization/test_util.h"
#include "tensorstore/static_cast.h"
#include "tensorstore/util/status_testutil.h"
#include "tensorstore/util/str_cat.h"

namespace {

using ::tensorstore::AllocateAndConstructShared;
using ::tensorstore::DataType;
using ::tensorstore::DataTypeId;
using ::tensorstore::DataTypeIdOf;
using ::tensorstore::dtype_v;
using ::tensorstore::EqualityComparisonKind;
using ::tensorstore::Index;
using ::tensorstore::IsElementType;
using ::tensorstore::IsTrivial;
using ::tensorstore::StaticDataTypeCast;
using ::tensorstore::StatusIs;
using ::tensorstore::unchecked;
using ::tensorstore::internal::IterationBufferKind;
using ::tensorstore::internal::IterationBufferPointer;
using ::tensorstore::serialization::SerializationRoundTrip;
using ::tensorstore::serialization::TestSerializationRoundTrip;
using ::testing::HasSubstr;

// For every built-in data type `T`: import `tensorstore::dtypes::T` into this
// namespace, and check at compile time that `DataTypeId::T` lies in the
// half-open range `[0, DataTypeId::num_ids)`.
#define X(T, ...)                                         \
  using ::tensorstore::dtypes::T;                         \
  static_assert(!(static_cast<int>(DataTypeId::T) < 0) && \
                static_cast<int>(DataTypeId::num_ids) >   \
                    static_cast<int>(DataTypeId::T));     \
  /**/
TENSORSTORE_FOR_EACH_DATA_TYPE(X)
#undef X

namespace is_element_type_tests {
struct ClassT {};
union UnionT {};
enum class EnumT {};
static_assert(IsElementType<int>);
static_assert(IsElementType<void>);
static_assert(IsElementType<const void>);
static_assert(IsElementType<const int>);
static_assert(IsElementType<const int*>);
static_assert(IsElementType<const int* const>);
static_assert(!IsElementType<volatile int>);
static_assert(!IsElementType<int[]>);
static_assert(!IsElementType<int(int)>);
static_assert(!IsElementType<int&>);
static_assert(!IsElementType<const int&>);
static_assert(!IsElementType<volatile int&>);
static_assert(!IsElementType<int&&>);
static_assert(IsElementType<ClassT>);
static_assert(IsElementType<UnionT>);
static_assert(IsElementType<EnumT>);
static_assert(IsElementType<int ClassT::*>);
static_assert(IsElementType<int (ClassT::*)(int)>);
}  // namespace is_element_type_tests

// Verifies that `TENSORSTORE_FOR_EACH_DATA_TYPE` enumerates the data types in
// `DataTypeId` order starting at 0. It also verifies that the enumeration
// reaches `DataTypeId::num_ids`, so no id is skipped.
TEST(ElementOperationsTest, DataTypeIdOrder) {
  int i = 0;
#define X(T, ...) EXPECT_EQ(i++, static_cast<int>(DataTypeId::T));
  TENSORSTORE_FOR_EACH_DATA_TYPE(X)
#undef X
  // Without this check, a macro list that stopped short of the last id would
  // still pass.
  EXPECT_EQ(static_cast<int>(DataTypeId::num_ids), i);
}

TEST(ElementOperationsTest, UnsignedIntBasic) {
  using Element = unsigned int;
  DataType dtype = dtype_v<Element>;
  // The runtime descriptor must report the static properties of `Element`.
  EXPECT_EQ(dtype->type, typeid(Element));
  EXPECT_EQ(dtype->size, sizeof(Element));
  EXPECT_EQ(dtype->alignment, alignof(Element));
}

TEST(ElementOperationsTest, UnsignedIntStaticDynamicConversion) {
  // An unchecked cast to the matching static type must be accepted both for
  // the static `dtype_v<unsigned int>` and for a dynamic `DataType` holding it.
  StaticDataTypeCast<unsigned int, unchecked>(dtype_v<unsigned int>);

  DataType dynamic_dtype = dtype_v<unsigned int>;
  StaticDataTypeCast<unsigned int, unchecked>(dynamic_dtype);
}

TEST(ElementOperationsTest, UnsignedIntConstruct) {
  DataType dtype = dtype_v<unsigned int>;

  // Raw, suitably aligned storage for five `unsigned int` objects.
  constexpr Index kCount = 5;
  alignas(unsigned int) unsigned char storage[sizeof(unsigned int) * kCount];
  auto* elements = reinterpret_cast<unsigned int*>(storage);

  // Constructing and destroying `unsigned int` does nothing observable; this
  // only checks that both entry points can be invoked.
  dtype->construct(kCount, elements);
  dtype->destroy(kCount, elements);
}

TEST(ElementOperationsTest, UnsignedIntCompareEqual) {
  DataType dtype = dtype_v<unsigned int>;

  unsigned int lhs[5] = {1, 2, 2, 5, 6};
  unsigned int rhs[5] = {1, 2, 3, 4, 6};

  const auto& compare_equal =
      dtype->compare_equal[static_cast<size_t>(EqualityComparisonKind::equal)]
          .array_array;

  // Invokes the strided variant (generated by SimpleElementwiseFunction) over
  // an `outer_count x inner_count` region. It steps through `lhs` two
  // elements at a time and through `rhs` one element at a time.
  const auto compare_strided = [&](Index outer_count, Index inner_count) {
    return compare_equal[IterationBufferKind::kStrided](
        nullptr, {outer_count, inner_count},
        IterationBufferPointer{lhs, Index(0), sizeof(unsigned int) * 2},
        IterationBufferPointer{rhs, Index(0), sizeof(unsigned int)},
        /*status=*/nullptr);
  };

  EXPECT_TRUE(compare_strided(0, 0));   // Empty region.
  EXPECT_TRUE(compare_strided(1, 2));   // {1, 2} vs {1, 2}.
  EXPECT_FALSE(compare_strided(1, 3));  // {1, 2, 6} vs {1, 2, 3}.
}

TEST(ElementOperationsTest, UnsignedIntCompareIdentical) {
  DataType dtype = dtype_v<unsigned int>;

  unsigned int lhs[5] = {1, 2, 2, 5, 6};
  unsigned int rhs[5] = {1, 2, 3, 4, 6};

  const auto& compare_identical =
      dtype->compare_equal[static_cast<size_t>(
                               EqualityComparisonKind::identical)]
          .array_array;

  // Invokes the strided variant (generated by SimpleElementwiseFunction) over
  // an `outer_count x inner_count` region. It steps through `lhs` two
  // elements at a time and through `rhs` one element at a time.
  const auto compare_strided = [&](Index outer_count, Index inner_count) {
    return compare_identical[IterationBufferKind::kStrided](
        nullptr, {outer_count, inner_count},
        IterationBufferPointer{lhs, Index(0), sizeof(unsigned int) * 2},
        IterationBufferPointer{rhs, Index(0), sizeof(unsigned int)},
        /*status=*/nullptr);
  };

  EXPECT_TRUE(compare_strided(0, 0));   // Empty region.
  EXPECT_TRUE(compare_strided(1, 2));   // {1, 2} vs {1, 2}.
  EXPECT_FALSE(compare_strided(1, 3));  // {1, 2, 6} vs {1, 2, 3}.
}

TEST(ElementOperationsTest, UnsignedIntCopyAssign) {
  DataType dtype = dtype_v<unsigned int>;
  // Marks destination slots that the kernel must leave untouched.
  constexpr unsigned int kSentinel = 0xFFFFFFFF;

  unsigned int src[] = {1, 2, 3, 4, 5};
  unsigned int dst[5] = {kSentinel, kSentinel, kSentinel, kSentinel,
                         kSentinel};
  // Strided variant generated by SimpleElementwiseFunction.
  const auto& copy_strided = dtype->copy_assign[IterationBufferKind::kStrided];

  // Gather: src[0], src[2] -> dst[0], dst[1].
  EXPECT_TRUE(copy_strided(
      nullptr, {1, 2},
      IterationBufferPointer{src, Index(0), sizeof(unsigned int) * 2},
      IterationBufferPointer{dst, Index(0), sizeof(unsigned int)},
      /*status=*/nullptr));
  EXPECT_THAT(dst,
              ::testing::ElementsAre(1, 3, kSentinel, kSentinel, kSentinel));

  // Scatter: src[0], src[1] -> dst[1], dst[3].
  EXPECT_TRUE(copy_strided(
      nullptr, {1, 2},
      IterationBufferPointer{src, Index(0), sizeof(unsigned int)},
      IterationBufferPointer{dst + 1, Index(0), sizeof(unsigned int) * 2},
      /*status=*/nullptr));
  EXPECT_THAT(dst, ::testing::ElementsAre(1, 1, kSentinel, 2, kSentinel));

  // Contiguous: src[0], src[1] -> dst[1], dst[2].
  EXPECT_TRUE(copy_strided(
      nullptr, {1, 2},
      IterationBufferPointer{src, Index(0), sizeof(unsigned int)},
      IterationBufferPointer{dst + 1, Index(0), sizeof(unsigned int)},
      /*status=*/nullptr));
  EXPECT_THAT(dst, ::testing::ElementsAre(1, 1, 2, 2, kSentinel));
}

// Exercises the strided `move_assign` kernel for `unsigned int` with gather,
// scatter, and contiguous access patterns.
TEST(ElementOperationsTest, UnsignedIntMoveAssign) {
  DataType r = dtype_v<unsigned int>;

  unsigned int source_arr[] = {1, 2, 3, 4, 5};
  unsigned int dest_arr[5] = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
                              0xFFFFFFFF};
  // Call the strided_function variant generated by
  // SimpleElementwiseFunction. As in the copy-assign test, check the returned
  // success flag rather than silently discarding it.

  // Gather: source_arr[0], source_arr[2] -> dest_arr[0], dest_arr[1].
  EXPECT_TRUE(r->move_assign[IterationBufferKind::kStrided](
      nullptr, {1, 2},
      IterationBufferPointer{source_arr, Index(0), sizeof(unsigned int) * 2},
      IterationBufferPointer{dest_arr, Index(0), sizeof(unsigned int)},
      /*status=*/nullptr));
  EXPECT_THAT(dest_arr,
              ::testing::ElementsAre(1, 3, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF));

  // Scatter: source_arr[0], source_arr[1] -> dest_arr[1], dest_arr[3].
  EXPECT_TRUE(r->move_assign[IterationBufferKind::kStrided](
      nullptr, {1, 2},
      IterationBufferPointer{source_arr, Index(0), sizeof(unsigned int)},
      IterationBufferPointer{dest_arr + 1, Index(0), sizeof(unsigned int) * 2},
      /*status=*/nullptr));
  EXPECT_THAT(dest_arr,
              ::testing::ElementsAre(1, 1, 0xFFFFFFFF, 2, 0xFFFFFFFF));

  // Contiguous: source_arr[0], source_arr[1] -> dest_arr[1], dest_arr[2].
  EXPECT_TRUE(r->move_assign[IterationBufferKind::kStrided](
      nullptr, {1, 2},
      IterationBufferPointer{source_arr, Index(0), sizeof(unsigned int)},
      IterationBufferPointer{dest_arr + 1, Index(0), sizeof(unsigned int)},
      /*status=*/nullptr));
  EXPECT_THAT(dest_arr, ::testing::ElementsAre(1, 1, 2, 2, 0xFFFFFFFF));
}

// Exercises the strided `initialize` kernel for `unsigned int`. Initialized
// elements are expected to become 0, and the other elements must keep their
// previous value.
TEST(ElementOperationsTest, UnsignedIntInitialize) {
  DataType r = dtype_v<unsigned int>;
  unsigned int dest_arr[5] = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
                              0xFFFFFFFF};
  // Call the strided_function variant generated by
  // SimpleElementwiseFunction, and check the returned success flag instead of
  // discarding it.

  // Initialize dest_arr[0] and dest_arr[2].
  EXPECT_TRUE(r->initialize[IterationBufferKind::kStrided](
      nullptr, {1, 2},
      IterationBufferPointer{dest_arr, Index(0), sizeof(unsigned int) * 2},
      /*status=*/nullptr));
  EXPECT_THAT(dest_arr,
              ::testing::ElementsAre(0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0xFFFFFFFF));

  // Initialize dest_arr[3] and dest_arr[4].
  EXPECT_TRUE(r->initialize[IterationBufferKind::kStrided](
      nullptr, {1, 2},
      IterationBufferPointer{dest_arr + 3, Index(0), sizeof(unsigned int)},
      /*status=*/nullptr));
  EXPECT_THAT(dest_arr, ::testing::ElementsAre(0, 0xFFFFFFFF, 0, 0, 0));
}

TEST(ElementOperationsTest, UnsignedIntAppendToString) {
  DataType dtype = dtype_v<unsigned int>;
  unsigned int value = 5;
  // Pre-seed the output so the test shows the text is appended, not
  // assigned.
  std::string out = " ";
  dtype->append_to_string(&out, &value);
  EXPECT_EQ(" 5", out);
}

TEST(StaticElementRepresentationDeathTest, UnsignedInt) {
  // In debug builds, an unchecked cast of a `float` data type to the static
  // `unsigned int` data type must trip the StaticCast assertion.
  const auto mismatched_cast = [] {
    StaticDataTypeCast<unsigned int, unchecked>(DataType(dtype_v<float>));
  };
  EXPECT_DEBUG_DEATH(mismatched_cast(), "StaticCast is not valid");
}

template <typename T>
void TestCompareIdenticalFloat() {
  // Converts both arguments to `T` before comparing. This way the `double`
  // and `int` literals below exercise the `T` overload of `CompareIdentical`.
  const auto identical = [](auto lhs, auto rhs) {
    const T a = static_cast<T>(lhs);
    const T b = static_cast<T>(rhs);
    return tensorstore::internal_data_type::CompareIdentical(a, b);
  };

  // Ordinary finite values.
  EXPECT_TRUE(identical(1.0, 1.0));
  EXPECT_FALSE(identical(1.0, 2.0));

  // Zeros are distinguished by sign.
  EXPECT_TRUE(identical(+0.0, +0.0));
  EXPECT_TRUE(identical(-0.0, -0.0));
  EXPECT_FALSE(identical(+0.0, -0.0));

  // Infinities are identical only to the same-signed infinity.
  EXPECT_TRUE(identical(INFINITY, INFINITY));
  EXPECT_TRUE(identical(-INFINITY, -INFINITY));
  EXPECT_FALSE(identical(-INFINITY, INFINITY));

  // Unlike `==`, NaN is identical to NaN, but not to a number.
  EXPECT_TRUE(identical(NAN, NAN));
  EXPECT_FALSE(identical(NAN, 1));
  EXPECT_FALSE(identical(1, NAN));
}

template <typename T>
void TestCompareIdenticalComplex() {
  using Component = typename T::value_type;
  // Builds `T{re_a, im_a}` and `T{re_b, im_b}`, converting each part to the
  // component type, and compares them with `CompareIdentical`.
  const auto identical = [](auto re_a, auto im_a, auto re_b, auto im_b) {
    const T a{static_cast<Component>(re_a), static_cast<Component>(im_a)};
    const T b{static_cast<Component>(re_b), static_cast<Component>(im_b)};
    return tensorstore::internal_data_type::CompareIdentical(a, b);
  };

  // Both the real and imaginary parts must match.
  EXPECT_TRUE(identical(1.0, 2.0, 1.0, 2.0));
  EXPECT_FALSE(identical(1.0, 2.0, 1.0, 3.0));
  EXPECT_FALSE(identical(1.0, 2.0, 3.0, 2.0));

  // The sign of a zero imaginary part matters.
  EXPECT_TRUE(identical(1.0, +0.0, 1.0, +0.0));
  EXPECT_FALSE(identical(1.0, +0.0, 1.0, -0.0));

  // A NaN part matches only a NaN part, and the other part must still match.
  EXPECT_TRUE(identical(1.0, NAN, 1.0, NAN));
  EXPECT_FALSE(identical(1.0, NAN, 2.0, NAN));
  EXPECT_FALSE(identical(1.0, NAN, 1.0, 2.0));
}

// Real floating-point element types share the checks in
// `TestCompareIdenticalFloat`.
TEST(CompareIdenticalTest, Float32) {
  using Float = tensorstore::dtypes::float32_t;
  TestCompareIdenticalFloat<Float>();
}

TEST(CompareIdenticalTest, Float64) {
  using Float = tensorstore::dtypes::float64_t;
  TestCompareIdenticalFloat<Float>();
}

TEST(CompareIdenticalTest, Bfloat16) {
  using Float = tensorstore::dtypes::bfloat16_t;
  TestCompareIdenticalFloat<Float>();
}

TEST(CompareIdenticalTest, Float16) {
  using Float = tensorstore::dtypes::float16_t;
  TestCompareIdenticalFloat<Float>();
}

// Complex element types share the checks in `TestCompareIdenticalComplex`.
TEST(CompareIdenticalTest, Complex64) {
  using Complex = tensorstore::dtypes::complex64_t;
  TestCompareIdenticalComplex<Complex>();
}

TEST(CompareIdenticalTest, Complex128) {
  using Complex = tensorstore::dtypes::complex128_t;
  TestCompareIdenticalComplex<Complex>();
}

// Checks the strided `identical` comparison kernel for `float`. Unlike `==`,
// it distinguishes -0.0 from +0.0 and treats NaN as identical to NaN.
TEST(ElementOperationsTest, FloatCompareIdentical) {
  // The arrays hold `float`s, so the `float` data type must be used. Using
  // any other element type would reinterpret their bytes instead of testing
  // floating-point semantics.
  DataType r = dtype_v<float>;

  float arr1[5] = {1.0f, -0.0f, NAN, INFINITY, 5.0f};
  float arr2[5] = {1.0f, +0.0f, NAN, INFINITY, 5.0f};
  // Call the strided_function variant generated by
  // SimpleElementwiseFunction.
  const auto& compare_identical =
      r->compare_equal[static_cast<size_t>(EqualityComparisonKind::identical)]
          .array_array;

  // Element 0: 1.0 vs 1.0.
  EXPECT_TRUE(compare_identical[IterationBufferKind::kStrided](
      nullptr, {1, 1}, IterationBufferPointer{arr1, Index(0), sizeof(float)},
      IterationBufferPointer{arr2, Index(0), sizeof(float)},
      /*status=*/nullptr));

  // Elements 0-1: includes -0.0 vs +0.0, which are not identical.
  EXPECT_FALSE(compare_identical[IterationBufferKind::kStrided](
      nullptr, {1, 2}, IterationBufferPointer{arr1, Index(0), sizeof(float)},
      IterationBufferPointer{arr2, Index(0), sizeof(float)},
      /*status=*/nullptr));

  // Elements 0, 2, 4: 1.0, NaN, 5.0 on both sides.
  EXPECT_TRUE(compare_identical[IterationBufferKind::kStrided](
      nullptr, {1, 3},
      IterationBufferPointer{arr1, Index(0), sizeof(float) * 2},
      IterationBufferPointer{arr2, Index(0), sizeof(float) * 2},
      /*status=*/nullptr));

  // Elements 0, 3: 1.0, INFINITY on both sides.
  EXPECT_TRUE(compare_identical[IterationBufferKind::kStrided](
      nullptr, {1, 2},
      IterationBufferPointer{arr1, Index(0), sizeof(float) * 3},
      IterationBufferPointer{arr2, Index(0), sizeof(float) * 3},
      /*status=*/nullptr));
}

// A class element type with a non-trivial member, used to exercise the
// construct/destroy/compare operations and custom-type serialization.
// A default-constructed `X` holds a `shared_ptr` that points at
// `constant_value` but shares ownership with an empty owner.
struct X {
  static constexpr int constant_value = 0;
  std::shared_ptr<const int> value{std::shared_ptr<void>{}, &constant_value};
};

TEST(ElementOperationsTest, Class) {
  DataType dtype = dtype_v<X>;

  // Construct two `X` objects in raw storage. Each one starts out pointing at
  // `X::constant_value`.
  alignas(X) unsigned char storage[sizeof(X) * 2];
  X* objects = reinterpret_cast<X*>(storage);
  dtype->construct(2, objects);
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(&X::constant_value, objects[i].value.get());
  }

  // Share one owner among the local handle and both objects.
  auto owner = std::make_shared<int>();
  objects[0].value = owner;
  objects[1].value = owner;
  EXPECT_EQ(3, owner.use_count());

  const auto& compare_equal =
      dtype->compare_equal[static_cast<size_t>(EqualityComparisonKind::equal)]
          .array_array;
  // Compares the first object against itself over an
  // `outer_count x inner_count` region with zero strides.
  const auto compare_self = [&](Index outer_count, Index inner_count) {
    return compare_equal[IterationBufferKind::kStrided](
        nullptr, {outer_count, inner_count},
        IterationBufferPointer{objects, Index(0), Index(0)},
        IterationBufferPointer{objects, Index(0), Index(0)},
        /*status=*/nullptr);
  };
  EXPECT_TRUE(compare_self(0, 0));
  // `X` defines no `operator==`. The expectation here is that even
  // self-comparison reports "not equal".
  EXPECT_FALSE(compare_self(1, 1));

  // Destroying the objects releases their references to `owner`.
  dtype->destroy(2, objects);
  EXPECT_EQ(1, owner.use_count());
}

TEST(DataTypeTest, Construct) {
  // A default-constructed `DataType` is invalid and equals another
  // default-constructed one.
  DataType dtype;
  EXPECT_FALSE(dtype.valid());
  EXPECT_EQ(DataType(), dtype);

  // Assigning a static data type makes it valid and equal to that type.
  dtype = dtype_v<float>;
  EXPECT_EQ(dtype, dtype_v<float>);
  EXPECT_TRUE(dtype.valid());
}

TEST(DataTypeTest, Comparison) {
  // Static data types compared with each other.
  EXPECT_TRUE(dtype_v<int> == dtype_v<int>);
  EXPECT_FALSE(dtype_v<int> != dtype_v<int>);
  EXPECT_FALSE(dtype_v<float> == dtype_v<int>);
  EXPECT_TRUE(dtype_v<float> != dtype_v<int>);

  // Dynamic data types compared with each other.
  EXPECT_TRUE(DataType(dtype_v<int>) == DataType(dtype_v<int>));
  EXPECT_FALSE(DataType(dtype_v<int>) != DataType(dtype_v<int>));
  EXPECT_TRUE(DataType(dtype_v<float>) != DataType(dtype_v<int>));

  // Dynamic data types compared with `std::type_info`, in both operand
  // orders.
  EXPECT_TRUE(DataType(dtype_v<float>) == typeid(float));
  EXPECT_FALSE(DataType(dtype_v<float>) != typeid(float));
  EXPECT_FALSE(DataType(dtype_v<float>) == typeid(int));
  EXPECT_TRUE(DataType(dtype_v<float>) != typeid(int));
  EXPECT_TRUE(typeid(float) == DataType(dtype_v<float>));
  EXPECT_FALSE(typeid(float) != DataType(dtype_v<float>));
}

// Verifies that the array returned by `AllocateAndConstructShared` destroys
// its elements when the last reference to it is released.
// (Suite name corrected from the misspelled "AllocateAndConsructSharedTest".)
TEST(AllocateAndConstructSharedTest, Destructor) {
  auto x = std::make_shared<int>();
  {
    auto ptr = AllocateAndConstructShared<std::shared_ptr<int>>(
        1, tensorstore::default_init);
    static_assert(
        std::is_same_v<std::shared_ptr<std::shared_ptr<int>>, decltype(ptr)>);
    ptr.get()[0] = x;
    EXPECT_EQ(2, x.use_count());
  }
  // Leaving the scope drops `ptr`, which must destroy the stored copy of `x`.
  EXPECT_EQ(1, x.use_count());
}

// With `value_init`, every element of the allocated array is zero.
// (Suite name corrected from the misspelled "AllocateAndConsructSharedTest".)
TEST(AllocateAndConstructSharedTest, ValueInitialization) {
  auto ptr = AllocateAndConstructShared<int>(2, tensorstore::value_init);
  EXPECT_EQ(0, ptr.get()[0]);
  EXPECT_EQ(0, ptr.get()[1]);
}

// Thread sanitizer considers `operator new` allocation failure an error, and
// prevents this death test from working.
#if !defined(THREAD_SANITIZER)
TEST(AllocateAndConstructSharedDeathTest, OutOfMemory) {
  // Requests 0xFFFFFFFFFFFFFFF (about 2^60) `int` elements, i.e. roughly 2^62
  // bytes. The allocation is expected to fail.
  const auto allocate = [] {
    AllocateAndConstructShared<int>(0xFFFFFFFFFFFFFFF,
                                    tensorstore::default_init);
  };
  // With exceptions enabled, the failure surfaces as `std::bad_alloc`.
  // Otherwise, the process is expected to terminate.
#if ABSL_HAVE_EXCEPTIONS
  EXPECT_THROW(allocate(), std::bad_alloc);
#else
  EXPECT_DEATH(allocate(), "");
#endif
}
#endif  // !defined(THREAD_SANITIZER)

// Each built-in data type reports its canonical name. This covers the same
// set of types as the `GetDataType` test, including `bfloat16`, which was
// previously missing here.
TEST(DataTypeTest, Name) {
  EXPECT_EQ("bool", DataType(dtype_v<bool_t>).name());
  EXPECT_EQ("byte", DataType(dtype_v<byte_t>).name());
  EXPECT_EQ("char", DataType(dtype_v<char_t>).name());
  EXPECT_EQ("int4", DataType(dtype_v<int4_t>).name());
  EXPECT_EQ("int2", DataType(dtype_v<int2_t>).name());
  // TODO(summivox): b/295577703 uint4
  EXPECT_EQ("int8", DataType(dtype_v<int8_t>).name());
  EXPECT_EQ("uint8", DataType(dtype_v<uint8_t>).name());
  EXPECT_EQ("int16", DataType(dtype_v<int16_t>).name());
  EXPECT_EQ("uint16", DataType(dtype_v<uint16_t>).name());
  EXPECT_EQ("int32", DataType(dtype_v<int32_t>).name());
  EXPECT_EQ("uint32", DataType(dtype_v<uint32_t>).name());
  EXPECT_EQ("int64", DataType(dtype_v<int64_t>).name());
  EXPECT_EQ("uint64", DataType(dtype_v<uint64_t>).name());
  EXPECT_EQ("bfloat16", DataType(dtype_v<bfloat16_t>).name());
  EXPECT_EQ("float16", DataType(dtype_v<float16_t>).name());
  EXPECT_EQ("float32", DataType(dtype_v<float32_t>).name());
  EXPECT_EQ("float64", DataType(dtype_v<float64_t>).name());
  EXPECT_EQ("complex64", DataType(dtype_v<complex64_t>).name());
  EXPECT_EQ("complex128", DataType(dtype_v<complex128_t>).name());
  EXPECT_EQ("string", DataType(dtype_v<string_t>).name());
  EXPECT_EQ("ustring", DataType(dtype_v<ustring_t>).name());
  EXPECT_EQ("json", DataType(dtype_v<json_t>).name());
}

TEST(DataTypeTest, PrintToOstream) {
  // A valid data type prints as its name. An unspecified one prints a
  // placeholder. `StrCat` is qualified explicitly rather than found via ADL.
  EXPECT_EQ("<unspecified>", tensorstore::StrCat(DataType()));
  EXPECT_EQ("int64", tensorstore::StrCat(dtype_v<int64_t>));
}

// `GetDataType` maps each canonical name back to its data type and returns
// an invalid `DataType` for unknown names. `ustring`, which was previously
// untested here, is now included so every name checked in the `Name` test
// round-trips.
TEST(DataTypeTest, GetDataType) {
  using ::tensorstore::GetDataType;
  EXPECT_EQ(dtype_v<int4_t>, GetDataType("int4"));
  EXPECT_EQ(dtype_v<int2_t>, GetDataType("int2"));
  // TODO(summivox): b/295577703 uint4
  EXPECT_EQ(dtype_v<int8_t>, GetDataType("int8"));
  EXPECT_EQ(dtype_v<uint8_t>, GetDataType("uint8"));
  EXPECT_EQ(dtype_v<int16_t>, GetDataType("int16"));
  EXPECT_EQ(dtype_v<uint16_t>, GetDataType("uint16"));
  EXPECT_EQ(dtype_v<int32_t>, GetDataType("int32"));
  EXPECT_EQ(dtype_v<uint32_t>, GetDataType("uint32"));
  EXPECT_EQ(dtype_v<int64_t>, GetDataType("int64"));
  EXPECT_EQ(dtype_v<uint64_t>, GetDataType("uint64"));
  EXPECT_EQ(dtype_v<bfloat16_t>, GetDataType("bfloat16"));
  EXPECT_EQ(dtype_v<float16_t>, GetDataType("float16"));
  EXPECT_EQ(dtype_v<float32_t>, GetDataType("float32"));
  EXPECT_EQ(dtype_v<float64_t>, GetDataType("float64"));
  EXPECT_EQ(dtype_v<complex64_t>, GetDataType("complex64"));
  EXPECT_EQ(dtype_v<complex128_t>, GetDataType("complex128"));
  EXPECT_EQ(dtype_v<string_t>, GetDataType("string"));
  EXPECT_EQ(dtype_v<ustring_t>, GetDataType("ustring"));
  EXPECT_EQ(dtype_v<bool_t>, GetDataType("bool"));
  EXPECT_EQ(dtype_v<char_t>, GetDataType("char"));
  EXPECT_EQ(dtype_v<byte_t>, GetDataType("byte"));
  EXPECT_EQ(dtype_v<json_t>, GetDataType("json"));
  EXPECT_EQ(DataType(), GetDataType("foo"));
}

TEST(DataTypeCastTest, Basic) {
  // Casting an unspecified `DataType` yields the requested static type.
  EXPECT_THAT(StaticDataTypeCast<int>(DataType()),
              ::testing::Optional(dtype_v<int>));

  // Casting a mismatched data type fails with a descriptive error.
  const auto float_as_int32 =
      StaticDataTypeCast<int32_t>(DataType(dtype_v<float>));
  EXPECT_THAT(
      float_as_int32,
      StatusIs(
          absl::StatusCode::kInvalidArgument,
          HasSubstr("Cannot cast data type of float32 to data type of int32")));
}

// Builtin integer types share the data type id of the fixed-width alias of
// the same width and signedness.
static_assert(DataTypeIdOf<int> == DataTypeIdOf<int32_t>);
static_assert(DataTypeIdOf<unsigned int> == DataTypeIdOf<uint32_t>);
static_assert(DataTypeIdOf<long long> == DataTypeIdOf<int64_t>);
static_assert(DataTypeIdOf<unsigned long long> == DataTypeIdOf<uint64_t>);

// `long` is either 32 or 64 bits wide, and maps to the fixed-width id of
// that width.
static_assert(sizeof(long) == 4 || sizeof(long) == 8);
static_assert(DataTypeIdOf<long> ==
              DataTypeIdOf<std::conditional_t<sizeof(long) == 4, int32_t,
                                              int64_t>>);
static_assert(DataTypeIdOf<unsigned long> ==
              DataTypeIdOf<std::conditional_t<sizeof(long) == 4, uint32_t,
                                              uint64_t>>);

TEST(DataTypesOrder, Valid) {
  // `kDataTypes[i]` must be the data type whose id is `i`. In other words,
  // the table is laid out in `DataTypeId` enum order.
  for (size_t index = 0; index < tensorstore::kNumDataTypeIds; ++index) {
    const auto expected_id = static_cast<DataTypeId>(index);
    EXPECT_EQ(tensorstore::kDataTypes[index].id(), expected_id);
  }
}

TEST(SerializationTest, Valid) {
  // The unspecified data type and every entry of `kDataTypes` must survive a
  // serialization round trip.
  TestSerializationRoundTrip(DataType());
  for (const DataType& builtin : tensorstore::kDataTypes) {
    TestSerializationRoundTrip(builtin);
  }
}

TEST(SerializationTest, Invalid) {
  // A custom (non-built-in) data type such as `X` cannot be serialized.
  const auto round_trip = SerializationRoundTrip(DataType(dtype_v<X>));
  EXPECT_THAT(round_trip,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Cannot serialize custom data type:")));
}

// Boolean, integer, floating-point and complex element types are trivial.
static_assert(IsTrivial<bool>);
static_assert(IsTrivial<int2_t>);
static_assert(IsTrivial<int4_t>);
static_assert(IsTrivial<uint64_t>);
static_assert(IsTrivial<bfloat16_t>);
static_assert(IsTrivial<float16_t>);
static_assert(IsTrivial<float32_t>);
static_assert(IsTrivial<float64_t>);
static_assert(IsTrivial<complex64_t>);
static_assert(IsTrivial<complex128_t>);

// The string and JSON element types are not.
static_assert(!IsTrivial<string_t>);
static_assert(!IsTrivial<ustring_t>);
static_assert(!IsTrivial<json_t>);

}  // namespace
